# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run BERT on SQuAD 1.1 and SQuAD 2.0."""

from __future__ import absolute_import, division, print_function

import collections
import json
import math
import os
import random
import shutil
import time

# import horovod.tensorflow as hvd
import numpy as np
import six
import tensorflow as tf
from tensorflow.python.client import device_lib

import modeling
import optimization
import tokenization
from utils.create_squad_data import *
from utils.utils import LogEvalRunHook, LogTrainRunHook

# NPU modify start
# Add NPU package
from npu_bridge.estimator.npu.npu_config import NPURunConfig
from npu_bridge.estimator.npu.npu_estimator import NPUEstimator
from npu_bridge.estimator import npu_ops
from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
from hccl.manage.api import get_rank_size
from hccl.manage.api import get_rank_id
# NPU modify end

flags = tf.flags
FLAGS = None

def extract_run_squad_flags():

  ## Required parameters
  flags.DEFINE_string(
      "bert_config_file", None,
      "The config json file corresponding to the pre-trained BERT model. "
      "This specifies the model architecture.")

  flags.DEFINE_string("vocab_file", None,
                      "The vocabulary file that the BERT model was trained on.")

  flags.DEFINE_string(
      "output_dir", None,
      "The output directory where the model checkpoints will be written.")

  ## Other parameters

  flags.DEFINE_string(
      "dllog_path", "/results/bert_dllog.json",
      "filename where dllogger writes to")

  flags.DEFINE_string("train_file", None,
                      "SQuAD json for training. E.g., train-v1.1.json")

  flags.DEFINE_string(
      "predict_file", None,
      "SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
  flags.DEFINE_string(
      "eval_script", None,
      "SQuAD evaluate.py file to compute f1 and exact_match E.g., evaluate-v1.1.py")

  flags.DEFINE_string(
      "init_checkpoint", None,
      "Initial checkpoint (usually from a pre-trained BERT model).")

  flags.DEFINE_bool(
      "do_lower_case", True,
      "Whether to lower case the input text. Should be True for uncased "
      "models and False for cased models.")

  flags.DEFINE_integer(
      "max_seq_length", 384,
      "The maximum total input sequence length after WordPiece tokenization. "
      "Sequences longer than this will be truncated, and sequences shorter "
      "than this will be padded.")

  flags.DEFINE_integer(
      "doc_stride", 128,
      "When splitting up a long document into chunks, how much stride to "
      "take between chunks.")

  flags.DEFINE_integer(
      "max_query_length", 64,
      "The maximum number of tokens for the question. Questions longer than "
      "this will be truncated to this length.")

  flags.DEFINE_bool("do_train", False, "Whether to run training.")

  flags.DEFINE_bool("do_predict", False, "Whether to run eval on the dev set.")

  flags.DEFINE_integer("train_batch_size", 8, "Total batch size for training.")

  flags.DEFINE_integer("predict_batch_size", 8,
                       "Total batch size for predictions.")

  flags.DEFINE_float("learning_rate", 5e-6, "The initial learning rate for Adam.")

  flags.DEFINE_bool("use_trt", False, "Whether to use TF-TRT")

  flags.DEFINE_bool("horovod", False, "Whether to use Horovod for multi-gpu runs")
  flags.DEFINE_float("num_train_epochs", 3.0,
                     "Total number of training epochs to perform.")

  flags.DEFINE_float(
      "warmup_proportion", 0.1,
      "Proportion of training to perform linear learning rate warmup for. "
      "E.g., 0.1 = 10% of training.")

  flags.DEFINE_integer("save_checkpoints_steps", 5000,
                       "How often to save the model checkpoint.")
  flags.DEFINE_integer("display_loss_steps", 10,
                       "How often to print loss from estimator")

  flags.DEFINE_integer("iterations_per_loop", 1000,
                       "How many steps to make in each estimator call.")

  flags.DEFINE_integer("num_accumulation_steps", 1,
                       "Number of accumulation steps before gradient update" 
                        "Global batch size = num_accumulation_steps * train_batch_size")

  flags.DEFINE_integer(
      "n_best_size", 20,
      "The total number of n-best predictions to generate in the "
      "nbest_predictions.json output file.")

  flags.DEFINE_integer(
      "max_answer_length", 30,
      "The maximum length of an answer that can be generated. This is needed "
      "because the start and end predictions are not conditioned on one another.")


  flags.DEFINE_bool(
      "verbose_logging", False,
      "If true, all of the warnings related to data processing will be printed. "
      "A number of warnings are expected for a normal SQuAD evaluation.")

  flags.DEFINE_bool(
      "version_2_with_negative", False,
      "If true, the SQuAD examples contain some that do not have an answer.")

  flags.DEFINE_float(
      "null_score_diff_threshold", 0.0,
      "If null_score - best_non_null is greater than the threshold predict null.")

  flags.DEFINE_bool("amp", True, "Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.")
  flags.DEFINE_bool("use_xla", False, "Whether to enable XLA JIT compilation.")
  flags.DEFINE_integer("num_eval_iterations", None,
                       "How many eval iterations to run - performs inference on subset")

  # Triton Specific flags
  flags.DEFINE_bool("export_triton", False, "Whether to export saved model or run inference with Triton")
  flags.DEFINE_string("triton_model_name", "bert", "exports to appropriate directory for Triton")
  flags.DEFINE_integer("triton_model_version", 1, "exports to appropriate directory for Triton")
  flags.DEFINE_string("triton_server_url", "localhost:8001", "exports to appropriate directory for Triton")
  flags.DEFINE_bool("triton_model_overwrite", False, "If True, will overwrite an existing directory with the specified 'model_name' and 'version_name'")
  flags.DEFINE_integer("triton_max_batch_size", 8, "Specifies the 'max_batch_size' in the Triton model config. See the Triton documentation for more info.")
  flags.DEFINE_float("triton_dyn_batching_delay", 0, "Determines the dynamic_batching queue delay in milliseconds(ms) for the Triton model config. Use '0' or '-1' to specify static batching. See the Triton documentation for more info.")
  flags.DEFINE_integer("triton_engine_count", 1, "Specifies the 'instance_group' count value in the Triton model config. See the Triton documentation for more info.")
  flags.mark_flag_as_required("vocab_file")
  flags.mark_flag_as_required("bert_config_file")
  flags.mark_flag_as_required("output_dir")

  # NPU modify start
  # Add NPU params
  flags.DEFINE_string('input_files_dir', None, 'tfrecord input file path')

  flags.DEFINE_bool('use_tfrecord', False, 'Whether to use tfrecord as dataset')

  flags.DEFINE_bool('npu_gather', False, 'Whether to use gather_npu whose backward propagation avoids IndexedSlices')

  flags.DEFINE_bool('npu_bert_debug', False, 'If True, dropout and shuffle are disabled.')

  flags.DEFINE_bool('npu_bert_fused_gelu', True, 'Whether to use npu defined gelu op')

  flags.DEFINE_bool('npu_bert_npu_dropout', True, 'Whether to use npu defined dropout op')

  flags.DEFINE_bool("npu_bert_clip_by_global_norm", False, "Use clip_by_global_norm if True, or use clip_by_norm for each gradient")

  flags.DEFINE_bool('npu_bert_use_tdt', True, 'Whether to feed the dataset through TDT')

  flags.DEFINE_integer("npu_bert_loss_scale", 0, "Whether to use loss scale, -1 is disable, 0 is dynamic loss scale, >=1 is static loss scale")

  flags.DEFINE_integer('init_loss_scale_value', 2**32, 'Initial loss scale value for loss scale optimizer')

  flags.DEFINE_bool('hcom_parallel', True, 'Whether to use parallel allreduce')

  flags.DEFINE_bool("distributed", False, "Whether to use multi-npu")

  flags.DEFINE_integer("num_train_steps", 0, "Number of training steps.")
  # NPU modify end

  return flags.FLAGS

# NPU modify start
# Add Hook func
class _LogSessionRunHook(tf.train.SessionRunHook):
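    """Session hook that logs step, throughput (sentences/sec), loss, and learning rate every `every_n_iter` steps."""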
    def __init__(self, global_batch_size, every_n_iter=100):
        self.global_batch_size = global_batch_size
        self.every_n_iter = every_n_iter
        self._timer = tf.train.SecondOrStepTimer(
            every_steps=every_n_iter, every_secs=None)

    def before_run(self, run_context):
        return tf.train.SessionRunArgs(
            fetches=['global_step:0', 'total_loss:0', 'learning_rate:0'])

    def after_run(self, run_context, run_values):
        global_step, total_loss, lr = run_values.results
        print_step = global_step
        sent_per_sec = 0

        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
            steps_per_sec = elapsed_steps / elapsed_time
            sent_per_sec = self.global_batch_size * steps_per_sec

        if print_step % self.every_n_iter == 0:
            print('Step = %6i Throughput = %11.1f Loss = %9.6f LR = %6.4e' %
                   (print_step, sent_per_sec, total_loss, lr), flush=True)
# NPU modify end

def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 use_one_hot_embeddings):
  """Creates a classification model."""
  model = modeling.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      use_one_hot_embeddings=use_one_hot_embeddings,
      compute_type=tf.float16 if FLAGS.amp else tf.float32) # To be determined

  final_hidden = model.get_sequence_output()

  final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
  batch_size = final_hidden_shape[0]
  seq_length = final_hidden_shape[1]
  hidden_size = final_hidden_shape[2]

  output_weights = tf.get_variable(
      "cls/squad/output_weights", [2, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())

  final_hidden_matrix = tf.reshape(final_hidden,
                                   [batch_size * seq_length, hidden_size])
  logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
  logits = tf.nn.bias_add(logits, output_bias)

  logits = tf.reshape(logits, [batch_size, seq_length, 2])
  logits = tf.transpose(logits, [2, 0, 1])
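  # The leading dimension of size 2 is unstacked below into separate start and end logits.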

  unstacked_logits = tf.unstack(logits, axis=0, name='unstack')

  (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])

  return (start_logits, end_logits)

def get_frozen_tftrt_model(bert_config, shape, use_one_hot_embeddings, init_checkpoint):
  tf_config = tf.compat.v1.ConfigProto()
  tf_config.gpu_options.allow_growth = True
  output_node_names = ['unstack']

  with tf.Session(config=tf_config) as tf_sess:
    input_ids = tf.placeholder(tf.int32, shape, 'input_ids')
    input_mask = tf.placeholder(tf.int32, shape, 'input_mask')
    segment_ids = tf.placeholder(tf.int32, shape, 'segment_ids')

    (start_logits, end_logits) = create_model(bert_config=bert_config,
                                              is_training=False,
                                              input_ids=input_ids,
                                              input_mask=input_mask,
                                              segment_ids=segment_ids,
                                              use_one_hot_embeddings=use_one_hot_embeddings)


    tvars = tf.trainable_variables()
    (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    tf_sess.run(tf.global_variables_initializer())
    print("LOADED!")
    tf.compat.v1.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      else:
        init_string = ", *NOTTTTTTTTTTTTTTTTTTTTT"
        tf.compat.v1.logging.info("  name = %s, shape = %s%s", var.name, var.shape, init_string)

    frozen_graph = tf.graph_util.convert_variables_to_constants(tf_sess, 
            tf_sess.graph.as_graph_def(), output_node_names)

    num_nodes = len(frozen_graph.node)
    print('Converting graph using TensorFlow-TensorRT...')
    from tensorflow.python.compiler.tensorrt import trt_convert as trt
    converter = trt.TrtGraphConverter(
        input_graph_def=frozen_graph,
        nodes_blacklist=output_node_names,
        max_workspace_size_bytes=(4096 << 20) - 1000,
        precision_mode = "FP16" if FLAGS.amp else "FP32",
        minimum_segment_size=4,
        is_dynamic_op=True,
        maximum_cached_engines=1000
    )
    frozen_graph = converter.convert()

    print('Total node count before and after TF-TRT conversion:',
          num_nodes, '->', len(frozen_graph.node))
    print('TRT node count:',
          len([1 for n in frozen_graph.node if str(n.op) == 'TRTEngineOp']))
    
    with tf.io.gfile.GFile("frozen_modelTRT.pb", "wb") as f:
      f.write(frozen_graph.SerializeToString())      
        
  return frozen_graph


def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, amp=False,
                     use_one_hot_embeddings=False, distributed=False, rank_id=0): # NPU modify: Add HCCL params
  """Returns `model_fn` closure for Estimator."""

  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for Estimator."""
    if FLAGS.verbose_logging:
        tf.compat.v1.logging.info("*** Features ***")
        for name in sorted(features.keys()):
          tf.compat.v1.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    if not is_training and FLAGS.use_trt:
        trt_graph = get_frozen_tftrt_model(bert_config, input_ids.shape, use_one_hot_embeddings, init_checkpoint)
        (start_logits, end_logits) = tf.import_graph_def(trt_graph,
                input_map={'input_ids':input_ids, 'input_mask':input_mask, 'segment_ids':segment_ids},
                return_elements=['unstack:0', 'unstack:1'],
                name='')
        predictions = {
            "unique_ids": unique_ids,
            "start_logits": start_logits,
            "end_logits": end_logits,
        }
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode, predictions=predictions)
        return output_spec

    (start_logits, end_logits) = create_model(
        bert_config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    if init_checkpoint and (not distributed or rank_id == 0): # NPU modify: HCCL replace Hvd
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    if FLAGS.verbose_logging:
        tf.compat.v1.logging.info("**** Trainable Variables ****")
        for var in tvars:
          init_string = ""
          if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
          tf.compat.v1.logging.info(" %d name = %s, shape = %s%s", 0 if not distributed else rank_id, var.name, var.shape, # NPU modify: HCCL replace Hvd
                          init_string)


    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      seq_length = modeling.get_shape_list(input_ids)[1]

      def compute_loss(logits, positions):
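        # Cross-entropy between the one-hot gold position and the predicted
        # distribution over sequence positions, averaged over the batch.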
        one_hot_positions = tf.one_hot(
            positions, depth=seq_length, dtype=tf.float32)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        loss = -tf.reduce_mean(
            tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
        return loss

      start_positions = features["start_positions"]
      end_positions = features["end_positions"]

      start_loss = compute_loss(start_logits, start_positions)
      end_loss = compute_loss(end_logits, end_positions)

      total_loss = (start_loss + end_loss) / 2.0
      total_loss = tf.identity(total_loss, name='total_loss')

      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, None, True, amp, FLAGS.num_accumulation_steps) # To be determined

      output_spec = tf.estimator.EstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op)
    elif mode == tf.estimator.ModeKeys.PREDICT:
      '''
      dummy_op = tf.no_op()
      # Need to call mixed precision graph rewrite if fp16 to enable graph rewrite
      # To be determined
      
      if amp:
        loss_scaler = tf.train.experimental.FixedLossScale(1)
        dummy_op = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimization.LAMBOptimizer(learning_rate=0.0), loss_scaler)
      '''

      predictions = {
          "unique_ids": unique_ids,
          "start_logits": start_logits,
          "end_logits": end_logits,
      }
      output_spec = tf.estimator.EstimatorSpec(
          mode=mode, predictions=predictions)
    else:
      raise ValueError(
          "Only TRAIN and PREDICT modes are supported: %s" % (mode))

    return output_spec

  return model_fn


def input_fn_builder(input_file, batch_size, seq_length, is_training, drop_remainder,
                     distributed=False, rank_id=0, rank_size=1): # NPU modify: Add HCCL params
  """Creates an `input_fn` closure to be passed to Estimator."""

  name_to_features = {
      "unique_ids": tf.io.FixedLenFeature([], tf.int64),
      "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
      "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
      "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
  }

  if is_training:
    name_to_features["start_positions"] = tf.io.FixedLenFeature([], tf.int64)
    name_to_features["end_positions"] = tf.io.FixedLenFeature([], tf.int64)

  def _decode_record(record, name_to_features):
    """Decodes a record to a TensorFlow example."""
    example = tf.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
      t = example[name]
      if t.dtype == tf.int64:
        t = tf.to_int32(t)
      example[name] = t

    return example

  def input_fn():
    """The actual input function."""

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    if is_training:
        d = tf.data.TFRecordDataset(input_file, num_parallel_reads=4)
        # NPU modify start
        # HCCL replace Hvd
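        # Each rank reads a disjoint 1/rank_size slice of the input records.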
        if distributed:
            d = d.shard(rank_size, rank_id)
        # NPU modify end
        d = d.apply(tf.data.experimental.ignore_errors())
        if not FLAGS.npu_bert_debug: # NPU modify: Add debug
            d = d.shuffle(buffer_size=100)
        d = d.repeat()
    else:
        d = tf.data.TFRecordDataset(input_file)


    d = d.apply(
        tf.contrib.data.map_and_batch(
            lambda record: _decode_record(record, name_to_features),
            batch_size=batch_size,
            drop_remainder=drop_remainder))

    return d

  return input_fn



RawResult = collections.namedtuple("RawResult",
                                   ["unique_id", "start_logits", "end_logits"])


def get_predictions(all_examples, all_features, all_results, n_best_size, max_answer_length, 
  do_lower_case, version_2_with_negative, verbose_logging):
  """Get final predictions"""

  example_index_to_features = collections.defaultdict(list)
  for feature in all_features:
    example_index_to_features[feature.example_index].append(feature)

  unique_id_to_result = {}
  for result in all_results:
    unique_id_to_result[result.unique_id] = result

  _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
      "PrelimPrediction",
      ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])

  all_predictions = collections.OrderedDict()
  all_nbest_json = collections.OrderedDict()
  scores_diff_json = collections.OrderedDict()

  for (example_index, example) in enumerate(all_examples):
    features = example_index_to_features[example_index]

    prelim_predictions = []
    # keep track of the minimum score of null start+end of position 0
    score_null = 1000000  # large and positive
    min_null_feature_index = 0  # the paragraph slice with min null score
    null_start_logit = 0  # the start logit at the slice with min null score
    null_end_logit = 0  # the end logit at the slice with min null score
    for (feature_index, feature) in enumerate(features):
      result = unique_id_to_result[feature.unique_id]
      start_indexes = _get_best_indexes(result.start_logits, n_best_size)
      end_indexes = _get_best_indexes(result.end_logits, n_best_size)
      # if we could have irrelevant answers, get the min score of irrelevant
      if version_2_with_negative:
        feature_null_score = result.start_logits[0] + result.end_logits[0]
        if feature_null_score < score_null:
          score_null = feature_null_score
          min_null_feature_index = feature_index
          null_start_logit = result.start_logits[0]
          null_end_logit = result.end_logits[0]
      for start_index in start_indexes:
        for end_index in end_indexes:
          # We could hypothetically create invalid predictions, e.g., predict
          # that the start of the span is in the question. We throw out all
          # invalid predictions.
          if start_index >= len(feature.tokens):
            continue
          if end_index >= len(feature.tokens):
            continue
          if start_index not in feature.token_to_orig_map:
            continue
          if end_index not in feature.token_to_orig_map:
            continue
          if not feature.token_is_max_context.get(start_index, False):
            continue
          if end_index < start_index:
            continue
          length = end_index - start_index + 1
          if length > max_answer_length:
            continue
          prelim_predictions.append(
              _PrelimPrediction(
                  feature_index=feature_index,
                  start_index=start_index,
                  end_index=end_index,
                  start_logit=result.start_logits[start_index],
                  end_logit=result.end_logits[end_index]))

    if version_2_with_negative:
      prelim_predictions.append(
          _PrelimPrediction(
              feature_index=min_null_feature_index,
              start_index=0,
              end_index=0,
              start_logit=null_start_logit,
              end_logit=null_end_logit))
    prelim_predictions = sorted(
        prelim_predictions,
        key=lambda x: (x.start_logit + x.end_logit),
        reverse=True)

    _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "NbestPrediction", ["text", "start_logit", "end_logit"])

    seen_predictions = {}
    nbest = []
    for pred in prelim_predictions:
      if len(nbest) >= n_best_size:
        break
      feature = features[pred.feature_index]
      if pred.start_index > 0:  # this is a non-null prediction
        tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
        orig_doc_start = feature.token_to_orig_map[pred.start_index]
        orig_doc_end = feature.token_to_orig_map[pred.end_index]
        orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
        tok_text = " ".join(tok_tokens)

        # De-tokenize WordPieces that have been split off.
        tok_text = tok_text.replace(" ##", "")
        tok_text = tok_text.replace("##", "")

        # Clean whitespace
        tok_text = tok_text.strip()
        tok_text = " ".join(tok_text.split())
        orig_text = " ".join(orig_tokens)

        final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
        if final_text in seen_predictions:
          continue

        seen_predictions[final_text] = True
      else:
        final_text = ""
        seen_predictions[final_text] = True

      nbest.append(
          _NbestPrediction(
              text=final_text,
              start_logit=pred.start_logit,
              end_logit=pred.end_logit))

    # if we didn't include the empty option in the n-best, include it
    if version_2_with_negative:
      if "" not in seen_predictions:
        nbest.append(
            _NbestPrediction(
                text="", start_logit=null_start_logit,
                end_logit=null_end_logit))
    # In very rare edge cases we could have no valid predictions. So we
    # just create a nonce prediction in this case to avoid failure.
    if not nbest:
      nbest.append(
          _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

    assert len(nbest) >= 1

    total_scores = []
    best_non_null_entry = None
    for entry in nbest:
      total_scores.append(entry.start_logit + entry.end_logit)
      if not best_non_null_entry:
        if entry.text:
          best_non_null_entry = entry

    probs = _compute_softmax(total_scores)

    nbest_json = []
    for (i, entry) in enumerate(nbest):
      output = collections.OrderedDict()
      output["text"] = entry.text
      output["probability"] = probs[i]
      output["start_logit"] = entry.start_logit
      output["end_logit"] = entry.end_logit
      nbest_json.append(output)

    assert len(nbest_json) >= 1

    if not version_2_with_negative:
      all_predictions[example.qas_id] = nbest_json[0]["text"]
    else:
      # predict "" iff the null score - the score of best non-null > threshold
      score_diff = score_null - best_non_null_entry.start_logit - (
          best_non_null_entry.end_logit)
      scores_diff_json[example.qas_id] = score_diff
      if score_diff > FLAGS.null_score_diff_threshold:
        all_predictions[example.qas_id] = ""
      else:
        all_predictions[example.qas_id] = best_non_null_entry.text

    all_nbest_json[example.qas_id] = nbest_json
  return all_predictions, all_nbest_json, scores_diff_json

def write_predictions(all_examples, all_features, all_results, n_best_size,
                      max_answer_length, do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file,
                      version_2_with_negative, verbose_logging):
  """Write final predictions to the json file and log-odds of null if needed."""

  tf.compat.v1.logging.info("Writing predictions to: %s" % (output_prediction_file))
  tf.compat.v1.logging.info("Writing nbest to: %s" % (output_nbest_file))

  all_predictions, all_nbest_json, scores_diff_json = get_predictions(all_examples, all_features, 
    all_results, n_best_size, max_answer_length, do_lower_case, version_2_with_negative, verbose_logging)

  with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
    writer.write(json.dumps(all_predictions, indent=4) + "\n")

  with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
    writer.write(json.dumps(all_nbest_json, indent=4) + "\n")

  if version_2_with_negative:
    with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
      writer.write(json.dumps(scores_diff_json, indent=4) + "\n")


def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging):
  """Project the tokenized prediction back to the original text."""

  # When we created the data, we kept track of the alignment between original
  # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
  # now `orig_text` contains the span of our original text corresponding to the
  # span that we predicted.
  #
  # However, `orig_text` may contain extra characters that we don't want in
  # our prediction.
  #
  # For example, let's say:
  #   pred_text = steve smith
  #   orig_text = Steve Smith's
  #
  # We don't want to return `orig_text` because it contains the extra "'s".
  #
  # We don't want to return `pred_text` because it's already been normalized
  # (the SQuAD eval script also does punctuation stripping/lower casing but
  # our tokenizer does additional normalization like stripping accent
  # characters).
  #
  # What we really want to return is "Steve Smith".
  #
  # Therefore, we have to apply a semi-complicated alignment heuristic between
  # `pred_text` and `orig_text` to get a character-to-character alignment. This
  # can fail in certain cases in which case we just return `orig_text`.

  def _strip_spaces(text):
    ns_chars = []
    ns_to_s_map = collections.OrderedDict()
    for (i, c) in enumerate(text):
      if c == " ":
        continue
      ns_to_s_map[len(ns_chars)] = i
      ns_chars.append(c)
    ns_text = "".join(ns_chars)
    return (ns_text, ns_to_s_map)

  # We first tokenize `orig_text`, strip whitespace from the result
  # and `pred_text`, and check if they are the same length. If they are
  # NOT the same length, the heuristic has failed. If they are the same
  # length, we assume the characters are one-to-one aligned.
  tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)

  tok_text = " ".join(tokenizer.tokenize(orig_text))

  start_position = tok_text.find(pred_text)
  if start_position == -1:
    if verbose_logging:
      tf.compat.v1.logging.info(
          "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
    return orig_text
  end_position = start_position + len(pred_text) - 1

  (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
  (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

  if len(orig_ns_text) != len(tok_ns_text):
    if verbose_logging:
      tf.compat.v1.logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
                      orig_ns_text, tok_ns_text)
    return orig_text

  # We then project the characters in `pred_text` back to `orig_text` using
  # the character-to-character alignment.
  tok_s_to_ns_map = {}
  for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
    tok_s_to_ns_map[tok_index] = i

  orig_start_position = None
  if start_position in tok_s_to_ns_map:
    ns_start_position = tok_s_to_ns_map[start_position]
    if ns_start_position in orig_ns_to_s_map:
      orig_start_position = orig_ns_to_s_map[ns_start_position]

  if orig_start_position is None:
    if verbose_logging:
      tf.compat.v1.logging.info("Couldn't map start position")
    return orig_text

  orig_end_position = None
  if end_position in tok_s_to_ns_map:
    ns_end_position = tok_s_to_ns_map[end_position]
    if ns_end_position in orig_ns_to_s_map:
      orig_end_position = orig_ns_to_s_map[ns_end_position]

  if orig_end_position is None:
    if verbose_logging:
      tf.compat.v1.logging.info("Couldn't map end position")
    return orig_text

  output_text = orig_text[orig_start_position:(orig_end_position + 1)]
  return output_text


def _get_best_indexes(logits, n_best_size):
  """Get the n-best logits from a list."""
  index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

  best_indexes = []
  for i in range(len(index_and_score)):
    if i >= n_best_size:
      break
    best_indexes.append(index_and_score[i][0])
  return best_indexes


def _compute_softmax(scores):
  """Compute softmax probability over raw logits."""
  if not scores:
    return []

  max_score = None
  for score in scores:
    if max_score is None or score > max_score:
      max_score = score

  exp_scores = []
  total_sum = 0.0
  for score in scores:
    x = math.exp(score - max_score)
    exp_scores.append(x)
    total_sum += x

  probs = []
  for score in exp_scores:
    probs.append(score / total_sum)
  return probs



def validate_flags_or_throw(bert_config):
  """Validate the input FLAGS or throw an exception."""
  tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                FLAGS.init_checkpoint)

  if not FLAGS.do_train and not FLAGS.do_predict and not FLAGS.export_triton:
    raise ValueError("At least one of `do_train` or `do_predict` or `export_SavedModel` must be True.")

  if FLAGS.do_train:
    if not FLAGS.train_file:
      raise ValueError(
          "If `do_train` is True, then `train_file` must be specified.")
  if FLAGS.do_predict:
    if not FLAGS.predict_file:
      raise ValueError(
          "If `do_predict` is True, then `predict_file` must be specified.")

  if FLAGS.max_seq_length > bert_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, bert_config.max_position_embeddings))

  if FLAGS.max_seq_length <= FLAGS.max_query_length + 3:
    raise ValueError(
        "The max_seq_length (%d) must be greater than max_query_length "
        "(%d) + 3" % (FLAGS.max_seq_length, FLAGS.max_query_length))


def export_model(estimator, export_dir, init_checkpoint):
    """Exports a checkpoint in SavedModel format in a directory structure compatible with Triton."""
    def serving_input_fn():
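        """Builds a raw serving input receiver with BERT's four int32 input placeholders."""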
        label_ids = tf.placeholder(tf.int32, [None,], name='unique_ids')
        input_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_ids')
        input_mask = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='input_mask')
        segment_ids = tf.placeholder(tf.int32, [None, FLAGS.max_seq_length], name='segment_ids')
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
            'unique_ids': label_ids,
            'input_ids': input_ids,
            'input_mask': input_mask,
            'segment_ids': segment_ids,
        })()
        return input_fn

    saved_dir = estimator.export_savedmodel(
        export_dir,
        serving_input_fn,
        assets_extra=None,
        as_text=False,
        checkpoint_path=init_checkpoint,
        strip_default_attrs=False)

    model_name = FLAGS.triton_model_name

    model_folder = export_dir + "/triton_models/" + model_name
    version_folder = model_folder + "/" + str(FLAGS.triton_model_version)
    final_model_folder = version_folder + "/model.savedmodel"

    if not os.path.exists(version_folder):
        os.makedirs(version_folder)
    
    if (not os.path.exists(final_model_folder)):
        os.rename(saved_dir, final_model_folder)
        print("Model saved to dir", final_model_folder)
    else:
        if (FLAGS.triton_model_overwrite):
            shutil.rmtree(final_model_folder)
            os.rename(saved_dir, final_model_folder)
            print("WARNING: Existing model was overwritten. Model dir: {}".format(final_model_folder))
        else:
            print("ERROR: Could not save Triton model. Folder already exists. Use '--triton_model_overwrite=True' if you would like to overwrite an existing model. Model dir: {}".format(final_model_folder))
            return

    # Now build the config for Triton. Check to make sure we can overwrite it, if it exists
    config_filename = os.path.join(model_folder, "config.pbtxt")

    if (os.path.exists(config_filename) and not FLAGS.triton_model_overwrite):
        print("ERROR: Could not save Triton model config. Config file already exists. Use '--triton_model_overwrite=True' if you would like to overwrite an existing model config. Model config: {}".format(config_filename))
        return
    
    config_template = r"""
name: "{model_name}"
platform: "tensorflow_savedmodel"
max_batch_size: {max_batch_size}
input [
    {{
        name: "unique_ids"
        data_type: TYPE_INT32
        dims: [ 1 ]
        reshape: {{ shape: [ ] }}
    }},
    {{
        name: "segment_ids"
        data_type: TYPE_INT32
        dims: {seq_length}
    }},
    {{
        name: "input_ids"
        data_type: TYPE_INT32
        dims: {seq_length}
    }},
    {{
        name: "input_mask"
        data_type: TYPE_INT32
        dims: {seq_length}
    }}
    ]
    output [
    {{
        name: "end_logits"
        data_type: TYPE_FP32
        dims: {seq_length}
    }},
    {{
        name: "start_logits"
        data_type: TYPE_FP32
        dims: {seq_length}
    }}
]
{dynamic_batching}
instance_group [
    {{
        count: {engine_count}
        kind: KIND_GPU
        gpus: [{gpu_list}]
    }}
]"""

    batching_str = ""
    max_batch_size = FLAGS.triton_max_batch_size

    if (FLAGS.triton_dyn_batching_delay > 0):

        # Use only full and half full batches
        pref_batch_size = [int(max_batch_size / 2.0), max_batch_size]

        batching_str = r"""
dynamic_batching {{
    preferred_batch_size: [{0}]
    max_queue_delay_microseconds: {1}
}}""".format(", ".join([str(x) for x in pref_batch_size]), int(FLAGS.triton_dyn_batching_delay * 1000.0))

    config_values = {
        "model_name": model_name,
        "max_batch_size": max_batch_size,
        "seq_length": FLAGS.max_seq_length,
        "dynamic_batching": batching_str,
        "gpu_list": ", ".join([x.name.split(":")[-1] for x in device_lib.list_local_devices() if x.device_type == "GPU"]),
        "engine_count": FLAGS.triton_engine_count
    }

    with open(model_folder + "/config.pbtxt", "w") as file:

        final_config_str = config_template.format_map(config_values)
        file.write(final_config_str)

# NPU modify start
# HCCL init func
def init_npu():
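    """Builds NPU system init/shutdown ops and a tf.Session configured with the NpuOptimizer."""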
    npu_init = npu_ops.initialize_system()
    npu_shutdown = npu_ops.shutdown_system()
    config = tf.ConfigProto()
    custom_op =  config.graph_options.rewrite_options.custom_optimizers.add()
    custom_op.name =  "NpuOptimizer"
    custom_op.parameter_map["use_off_line"].b = True # Train on Ascend NPU
    if FLAGS.amp:
        custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
    config.graph_options.rewrite_options.remapping = RewriterConfig.OFF  # Close Remap
    init_sess = tf.Session(config=config)

    return init_sess, npu_init, npu_shutdown
# NPU modify end

def main(_):
  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

  # NPU modify start
  npu_precision_mode = 'allow_fp32_to_fp16'  # default precision_mode; overridden below when FLAGS.amp is set
  rank_id = 0
  rank_size = 1

  if FLAGS.distributed:
      init_sess, npu_init, _ = init_npu()
      init_sess.run(npu_init)
      rank_id = get_rank_id()
      rank_size = get_rank_size()

  if FLAGS.amp:
      npu_precision_mode = 'allow_mix_precision'
  # NPU modify end

  bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

  validate_flags_or_throw(bert_config)

  tf.io.gfile.makedirs(FLAGS.output_dir)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  master_process = True
  training_hooks = []
  global_batch_size = FLAGS.train_batch_size * FLAGS.num_accumulation_steps
  hvd_rank = 0

  config = tf.compat.v1.ConfigProto()
  learning_rate = FLAGS.learning_rate

  # NPU modify start
  # HCCL replace Hvd
  if FLAGS.distributed:
      tf.compat.v1.logging.info("Multi-NPU training with HCCL")
      tf.compat.v1.logging.info("RANK_SIZE = %d, RANK_ID = %d", rank_size, rank_id)
      global_batch_size = FLAGS.train_batch_size * rank_size
      learning_rate = learning_rate * rank_size
  if FLAGS.use_xla:
    config.graph_options.optimizer_options.global_jit_level = tf.compat.v1.OptimizerOptions.ON_1
    if FLAGS.amp:
        tf.enable_resource_variables()

  # NPU run config
  run_config = NPURunConfig(
      model_dir=FLAGS.output_dir,
      session_config=config,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=2,
      save_summary_steps=0,
      iterations_per_loop=FLAGS.iterations_per_loop,
      hcom_parallel=FLAGS.hcom_parallel,
      enable_data_pre_proc=FLAGS.npu_bert_use_tdt,
      precision_mode=npu_precision_mode)
  # NPU modify end

  if master_process:
      tf.compat.v1.logging.info("***** Configuaration *****")
      for key in FLAGS.__flags.keys():
          tf.compat.v1.logging.info('  {}: {}'.format(key, getattr(FLAGS, key)))
      tf.compat.v1.logging.info("**************************")

  train_examples = None
  num_train_steps = None
  num_warmup_steps = None
  # NPU modify start
  # Add Hook and HCCL replace Hvd
  training_hooks.append(_LogSessionRunHook(global_batch_size, FLAGS.iterations_per_loop))
  training_hooks.append(LogTrainRunHook(global_batch_size, rank_id, FLAGS.save_checkpoints_steps))
  # NPU modify end

  # Prepare Training Data
  if FLAGS.do_train:
    train_examples = read_squad_examples(
        input_file=FLAGS.train_file, is_training=True,
        version_2_with_negative=FLAGS.version_2_with_negative)
    num_train_steps = int(
        len(train_examples) / global_batch_size * FLAGS.num_train_epochs)
    num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    # Pre-shuffle the input to avoid having to make a very large shuffle
    # buffer in the `input_fn`.
    rng = random.Random(12345)
    rng.shuffle(train_examples)

    start_index = 0 
    end_index = len(train_examples)
    tmp_filenames = [os.path.join(FLAGS.input_files_dir, "train.tf_record")] # NPU modify: change tfrecord save path

    # NPU modify start
    # HCCL replace Hvd
    if FLAGS.distributed:
        tmp_filenames = [os.path.join(FLAGS.input_files_dir, "train.tf_record{}".format(i)) for i in range(rank_size)]
        num_examples_per_rank = len(train_examples) // rank_size
        remainder = len(train_examples) % rank_size
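        # The first `remainder` ranks each take one extra example so every training example is assigned.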
        if rank_id < remainder:
            start_index = rank_id * (num_examples_per_rank+1)
            end_index = start_index + num_examples_per_rank + 1
        else:
            start_index = rank_id * num_examples_per_rank + remainder
            end_index = start_index + (num_examples_per_rank)

    if FLAGS.num_train_steps > 0:
        num_train_steps = FLAGS.num_train_steps
    # NPU modify end
    print("TMP FILES: ", tmp_filenames)

  # NPU modify start
  # Add model params of HCCL
  model_fn = model_fn_builder(
      bert_config=bert_config,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      amp=FLAGS.amp,
      distributed=FLAGS.distributed,
      rank_id=rank_id)

  # Replace NPUEstimator
  estimator = NPUEstimator(
      model_fn=model_fn,
      config=run_config)
  # NPU modify end

  if FLAGS.do_train:

    # We write to a temporary file to avoid storing very large constant tensors
    # in memory.
    if not FLAGS.use_tfrecord: # NPU modify: skip feature conversion when a pre-generated tfrecord is used (reduces data prep time)
        train_writer = FeatureWriter(
            filename=tmp_filenames[rank_id], # NPU modify: HCCL replace Hvd
            is_training=True)
        convert_examples_to_features(
            examples=train_examples[start_index:end_index],
            tokenizer=tokenizer,
            max_seq_length=FLAGS.max_seq_length,
            doc_stride=FLAGS.doc_stride,
            max_query_length=FLAGS.max_query_length,
            is_training=True,
            output_fn=train_writer.process_feature,
            verbose_logging=FLAGS.verbose_logging)
        train_writer.close()

        tf.compat.v1.logging.info("***** Running training *****")
        tf.compat.v1.logging.info("  Num orig examples = %d", end_index - start_index)
        tf.compat.v1.logging.info("  Num split examples = %d", train_writer.num_features)
        tf.compat.v1.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.compat.v1.logging.info("  Num steps = %d", num_train_steps)
        tf.compat.v1.logging.info("  LR = %f", learning_rate)
        del train_examples

    train_input_fn = input_fn_builder(
        input_file=tmp_filenames,
        batch_size=FLAGS.train_batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True,
        distributed=FLAGS.distributed,
        rank_id=rank_id,
        rank_size=rank_size) # HCCL params replace Hvd

    train_start_time = time.time()
    estimator.train(input_fn=train_input_fn, hooks=training_hooks, max_steps=num_train_steps)
    train_time_elapsed = time.time() - train_start_time
    train_time_wo_overhead = training_hooks[-1].total_time
    if train_time_elapsed > 0:
        avg_sentences_per_second = num_train_steps * global_batch_size * 1.0 / train_time_elapsed
    else:
        avg_sentences_per_second = 0

    if train_time_wo_overhead > 0:
        ss_sentences_per_second = (num_train_steps - training_hooks[-1].skipped) * global_batch_size * 1.0 / train_time_wo_overhead
    else:
        ss_sentences_per_second = 0

    # NPU modify start
    # HCCL shutdown and init
    if FLAGS.distributed:
        init_sess, npu_init, npu_shutdown = init_npu()
        init_sess.run(npu_shutdown)
        init_sess.run(npu_init)
    # NPU modify end

    if master_process:
        tf.compat.v1.logging.info("-----------------------------")
        tf.compat.v1.logging.info("Total Training Time = %0.2f for Sentences = %d", train_time_elapsed,
                        num_train_steps * global_batch_size)
        tf.compat.v1.logging.info("Total Training Time W/O Overhead = %0.2f for Sentences = %d", train_time_wo_overhead,
                        (num_train_steps - training_hooks[-1].skipped) * global_batch_size)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) with overhead = %0.2f", avg_sentences_per_second)
        tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
        tf.compat.v1.logging.info("-----------------------------")


  if FLAGS.export_triton and master_process:
    export_model(estimator, FLAGS.output_dir, FLAGS.init_checkpoint)

  if FLAGS.do_predict and master_process:
    eval_examples = read_squad_examples(
        input_file=FLAGS.predict_file, is_training=False,
        version_2_with_negative=FLAGS.version_2_with_negative)

    # Perform evaluation on subset, useful for profiling
    if FLAGS.num_eval_iterations is not None:
        eval_examples = eval_examples[:FLAGS.num_eval_iterations*FLAGS.predict_batch_size]

    eval_writer = FeatureWriter(
        filename=os.path.join(FLAGS.output_dir, "eval.tf_record"),
        is_training=False)
    eval_features = []

    def append_feature(feature):
      eval_features.append(feature)
      eval_writer.process_feature(feature)

    convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_seq_length,
        doc_stride=FLAGS.doc_stride,
        max_query_length=FLAGS.max_query_length,
        is_training=False,
        output_fn=append_feature,
        verbose_logging=FLAGS.verbose_logging)
    eval_writer.close()

    tf.compat.v1.logging.info("***** Running predictions *****")
    tf.compat.v1.logging.info("  Num orig examples = %d", len(eval_examples))
    tf.compat.v1.logging.info("  Num split examples = %d", len(eval_features))
    tf.compat.v1.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

    predict_input_fn = input_fn_builder(
        input_file=eval_writer.filename,
        batch_size=FLAGS.predict_batch_size,
        seq_length=FLAGS.max_seq_length,
        is_training=False,
        drop_remainder=True)

    all_results = []
    eval_hooks = [LogEvalRunHook(FLAGS.predict_batch_size)]
    eval_start_time = time.time()
    for result in estimator.predict(
        predict_input_fn, yield_single_examples=True, hooks=eval_hooks):
      if len(all_results) % 1000 == 0:
        tf.compat.v1.logging.info("Processing example: %d" % (len(all_results)))
      unique_id = int(result["unique_ids"])
      start_logits = [float(x) for x in result["start_logits"].flat]
      end_logits = [float(x) for x in result["end_logits"].flat]
      all_results.append(
          RawResult(
              unique_id=unique_id,
              start_logits=start_logits,
              end_logits=end_logits))

    eval_time_elapsed = time.time() - eval_start_time

    time_list = eval_hooks[-1].time_list
    time_list.sort()
    # Removing outliers (init/warmup) in throughput computation.
    eval_time_wo_overhead = sum(time_list[:int(len(time_list) * 0.99)])
    num_sentences = (int(len(time_list) * 0.99)) * FLAGS.predict_batch_size

    avg = np.mean(time_list)
    cf_50 = max(time_list[:int(len(time_list) * 0.50)])
    cf_90 = max(time_list[:int(len(time_list) * 0.90)])
    cf_95 = max(time_list[:int(len(time_list) * 0.95)])
    cf_99 = max(time_list[:int(len(time_list) * 0.99)])
    cf_100 = max(time_list[:int(len(time_list) * 1)])
    ss_sentences_per_second = num_sentences * 1.0 / eval_time_wo_overhead

    # NPU modify start
    # HCCL shutdown
    if FLAGS.distributed:
        init_sess, _, npu_shutdown = init_npu()
        init_sess.run(npu_shutdown)
        init_sess.close()
    # NPU modify end
    tf.compat.v1.logging.info("-----------------------------")
    tf.compat.v1.logging.info("Total Inference Time = %0.2f for Sentences = %d", eval_time_elapsed,
                    eval_hooks[-1].count * FLAGS.predict_batch_size)
    tf.compat.v1.logging.info("Total Inference Time W/O Overhead = %0.2f for Sentences = %d", eval_time_wo_overhead,
                    num_sentences)
    tf.compat.v1.logging.info("Summary Inference Statistics")
    tf.compat.v1.logging.info("Batch size = %d", FLAGS.predict_batch_size)
    tf.compat.v1.logging.info("Sequence Length = %d", FLAGS.max_seq_length)
    tf.compat.v1.logging.info("Precision = %s", "fp16" if FLAGS.amp else "fp32")
    tf.compat.v1.logging.info("Latency Confidence Level 50 (ms) = %0.2f", cf_50 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 90 (ms) = %0.2f", cf_90 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 95 (ms) = %0.2f", cf_95 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 99 (ms) = %0.2f", cf_99 * 1000)
    tf.compat.v1.logging.info("Latency Confidence Level 100 (ms) = %0.2f", cf_100 * 1000)
    tf.compat.v1.logging.info("Latency Average (ms) = %0.2f", avg * 1000)
    tf.compat.v1.logging.info("Throughput Average (sentences/sec) = %0.2f", ss_sentences_per_second)
    tf.compat.v1.logging.info("-----------------------------")

    output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
    output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")
    output_null_log_odds_file = os.path.join(FLAGS.output_dir, "null_odds.json")

    write_predictions(eval_examples, eval_features, all_results,
                      FLAGS.n_best_size, FLAGS.max_answer_length,
                      FLAGS.do_lower_case, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file,
                      FLAGS.version_2_with_negative, FLAGS.verbose_logging)

    if FLAGS.eval_script:
        import sys
        import subprocess
        eval_out = subprocess.check_output([sys.executable, FLAGS.eval_script,
                                          FLAGS.predict_file, output_prediction_file])
        scores = str(eval_out).strip()
        exact_match = float(scores.split(":")[1].split(",")[0])
        f1 = float(scores.split(":")[2].split(",")[0])
        print(str(eval_out))


if __name__ == "__main__":
  FLAGS = extract_run_squad_flags()
  tf.app.run()
